# -*- coding: utf-8 -*-
"""
Created on Tue Dec 01 21:05:42 2015

@author: Tim
"""
import matplotlib.pyplot as P
from pylab import *                ## import scientific database
close("all")                       ## close all windows
from netCDF4 import Dataset
import numpy as np
from mpl_toolkits.basemap import Basemap
from matplotlib.font_manager import FontProperties
from mpl_toolkits.axes_grid.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from wrf import getvar, interplevel, to_np, get_basemap, latlon_coords, CoordPair, vertcross
import glob
import cmocean
import cmocean.cm as cmo
import numpy.ma as ma
from descartes import PolygonPatch
from matplotlib.patches import Polygon
import cartopy.crs as ccrs
import matplotlib.ticker as mticker

from plot_marshall_functions import ext2bbox, load_baseimage_any
from plot_marshall_functions import gjson_feature_geometry, gjson2poly
from plot_marshall_functions import refine_dom, read_wrffr
from plot_marshall_functions import pyproj_wrf2wgs, cartopy_defwrfproj, mplext_outer

# Presumably gravitational acceleration in m s^-2 (TODO confirm); it is not
# referenced anywhere else in the visible part of this script.
g = 9.81

# One independent two-element zero array per corner coordinate
# (ll/ul/ur/lr = lower-left, upper-left, upper-right, lower-right).
# NOTE(review): none of these *_lona / *_lata arrays is used in the visible
# part of the script.
ll_lona, ll_lata = np.zeros(2), np.zeros(2)
ul_lona, ul_lata = np.zeros(2), np.zeros(2)
ur_lona, ur_lata = np.zeros(2), np.zeros(2)
lr_lona, lr_lata = np.zeros(2), np.zeros(2)

# Read the plotting window from the WRF files of each selected nest.
# np.arange(1,2) yields only i == 1, so only the d02 file pair is processed.
for i in np.arange(1,2):
    if i == 0:
        wrf_files = '/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/marshall/2021123012_maria/wrfinput_d01'
        # No fire output file is configured for d01; set this explicitly so
        # a stale or undefined name is never passed to Dataset() below.
        wrf_files2 = None
    elif i == 1:
        wrf_files = '/glade/scratch/tjuliano/fire/sagehen/WRF-Fire-LEAPHI/WRF/test/marshall/2021123012_maria/wrfinput_d02'
        wrf_files2 = '/glade/scratch/tjuliano/fire/marshall/joinwrfh/1min/big2_maria_run9_add_spotting/fire/wrfout_d02_2021-12-31_01:45:00'
    else:
        raise ValueError('No WRF files configured for domain index %d' % i)

    if wrf_files2 is None:
        raise ValueError('No fire output file configured for domain index %d' % i)

    print ('Reading WRF files...')

    # Index window cut out of the grids: columns i1:i2, rows j1:j2.
    i1 = 690
    i2 = 1050
    j1 = 930
    j2 = 1200

    # Slicing a netCDF variable copies the data into memory, so both files
    # can be closed as soon as the three fields have been read.
    with Dataset(wrf_files) as s1, Dataset(wrf_files2) as s2:
        # FXLAT / FXLONG: presumably the fire-mesh latitude / longitude
        # (TODO confirm); only the first index along the leading axis is used.
        lat2 = s2.variables['FXLAT'][0,j1:j2,i1:i2].data
        lon2 = s2.variables['FXLONG'][0,j1:j2,i1:i2].data
        # Fuel categories come from the input file (s1), not the output file.
        nfuel_cat = s1.variables['NFUEL_CAT'][0,j1:j2,i1:i2]

    # Corner coordinates of the window: first/last row x first/last column.
    ll_lon = lon2[0,0]
    ll_lat = lat2[0,0]
    ul_lon = lon2[-1,0]
    ul_lat = lat2[-1,0]
    ur_lon = lon2[-1,-1]
    ur_lat = lat2[-1,-1]
    lr_lon = lon2[0,-1]
    lr_lat = lat2[0,-1]

#ll_lon = -105.26
#ll_lat = 39.92
#ur_lon = -105.12
#ur_lat = 40.0

#print (ll_lon, ll_lat, ur_lon, ur_lat)

# Center point of the plotted window: midpoint of the lower-left and
# upper-right corners. A plain arithmetic mean is used for the longitude so
# the sign is correct in either hemisphere (the previous abs()-based form was
# only a midpoint when both longitudes were negative; for that case the two
# expressions give the same value).
ref_lat = (ll_lat + ur_lat)/2.
ref_lon = (ll_lon + ur_lon)/2.

print (ref_lat, ref_lon)

# Renumber fuel cats for plotting (in place), using the group labels of the
# colorbar:
#   NB: 14      ->  1
#   GR:  1..3   ->  2..4
#   SH:  4..7   ->  5..8
#   TL:  8..10  ->  9..11
#   SB: 11..13  -> 12..14
# Old and new codes overlap, so every category is first moved to a staging
# code (old code + 100) and only then to its final plotting code; this keeps
# a freshly renumbered cell from being matched again by a later step.
_staging_offset = 100
_fuel_codes = range(1, 15)

# pass 1: old code -> staging code
for _old_code in _fuel_codes:
    nfuel_cat[nfuel_cat == _old_code] = _old_code + _staging_offset

# pass 2: staging code -> plotting code (14 wraps around to 1, the rest shift up by one)
for _old_code in _fuel_codes:
    _plot_code = 1 if _old_code == 14 else _old_code + 1
    nfuel_cat[nfuel_cat == _old_code + _staging_offset] = _plot_code

# Anchor colors (RGB tuples) from which the discrete fuel colormap is sampled.
colors = [(0.4, 0.7607843137254902, 0.6470588235294118),
              (0.9882352941176471, 0.5529411764705883, 0.3843137254901961),
              (0.5529411764705883, 0.6274509803921569, 0.796078431372549),
              (0.9058823529411765, 0.5411764705882353, 0.7647058823529411),
              (0.6509803921568628, 0.8470588235294118, 0.32941176470588235),
              (1.0, 0.8509803921568627, 0.1843137254901961),
              (0.8980392156862745, 0.7686274509803922, 0.5803921568627451)]
# Bin edges in renumbered fuel-category codes: [1,2) [2,5) [5,9) [9,11].
n_bins_ranges = [1,2,5,9,11]
# Discretizes the interpolation into bins: one color per interval between edges.
n_bin = len(n_bins_ranges) - 1
cmap_name = 'my_list'
# NOTE(review): `cm` likely shadows the matplotlib `cm` module brought in by
# the pylab star import; the name is kept because the plotting code uses it.
cm = matplotlib.colors.LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bin)
# ncolors must be the number of colors in the colormap. Passing the number of
# bin edges instead (one more than the colormap holds) makes the last bin map
# to an index past the end of the colormap, which only renders correctly
# through the "over" color fallback.
norm = matplotlib.colors.BoundaryNorm(n_bins_ranges, cm.N)
# NOTE(review): bounds_gr is not used in the visible part of the script.
bounds_gr=np.linspace(1,14,14)

# create map subplots
# 1000 MB
#fig, ax = plt.subplots(111)
#ax = fig.add_axes([0.05,0.05,0.9,0.9])
#m = Basemap(llcrnrlon=ll_lon,llcrnrlat=ll_lat,urcrnrlon=ur_lon,urcrnrlat=ur_lat,lon_0=ref_lon,lat_0=ref_lat,projection='lcc',resolution='h')

# Single map axes in plain lon/lat (PlateCarree) coordinates.
proj = ccrs.PlateCarree()
f = plt.figure(figsize=(9,5))
ax = plt.subplot(111,projection=proj)

# NOTE(review): my_cmap is not used by the image below, and set_under()
# modifies the colormap object exposed by cmocean in place.
my_cmap = cmo.deep
my_cmap.set_under('lightskyblue')

# Draw the fuel categories, stretched over the lon/lat box spanned by the
# lower-left and upper-right corners. The color range comes from `norm`
# alone: its boundaries already run from 1 to 11, and Matplotlib rejects
# vmin/vmax when a Normalize instance is passed as well.
c1 = ax.imshow(nfuel_cat,
               cmap=cm,
               norm=norm,
               aspect='auto',
               interpolation='nearest',
               origin='lower',
               extent=[ll_lon,ur_lon,ll_lat,ur_lat])

# customize map
#m.drawparallels(np.arange(60.,94.,2.0),labels=[True,False,False,False],fontsize=28,linewidth=3.0)
#m.drawmeridians(np.arange(-130.,-100.,2.0),labels=[False,False,False,True],fontsize=28,linewidth=3.0)

# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Read fire perim json file and outline the perimeter on the map

def _perimeter_rings(coords):
    """Yield every ring found in nested GeoJSON-style coordinates.

    ``coords`` may be a single ring (a sequence of positions), a polygon
    (a sequence of rings) or a multi-polygon (a sequence of polygons).
    Each ring is yielded as an (N, 2) float array holding the first two
    components of every position. Rings are converted one at a time, so
    rings of different lengths never have to fit into one rectangular
    array (NumPy >= 1.24 raises ValueError for such ragged input).
    """
    if len(coords) == 0 or len(coords[0]) == 0:
        return
    if not hasattr(coords[0][0], '__len__'):
        # coords[0] is a bare position (its items are numbers), so coords
        # itself is one ring.
        yield np.asarray(coords, dtype=float)[:, :2]
        return
    for part in coords:
        yield from _perimeter_rings(part)


perim1 = gjson_feature_geometry(filename='marshall.json')

# Only the first entry returned by the reader is drawn.
poly = list(_perimeter_rings(perim1[0]['geometry']['coordinates']))

# One unfilled magenta outline per ring.
for ring in poly:
    p = Polygon(ring, facecolor='none', edgecolor='magenta',lw=1.5)
    ax.add_patch(p)

#print (poly[0])

#if len(poly) == 1:
#    poly = poly[0]
#    mappoly = dict(type='Polygon', coordinates=[[list(prjtrans(*i)) for i in poly]])
#else:
#    mappoly = dict(type='MultiPolygon',
#                   coordinates=[[[prjtrans(*i) for i in mp[0]] for mp in poly]])

# Observed perimeter
#pp1 = [
#      ax.add_patch(PolygonPatch(ii, fc='None', ec='magenta', zorder=2, lw=3.))
#      for ii in poly
#      ]

# Dashed black lon/lat grid drawn above the image, with tick labels.
gl = ax.gridlines(
    crs=ccrs.PlateCarree(),
    draw_labels=True,
    color='k',
    linestyle='--',
    linewidth=1.0,
    zorder=4,
)

# Label only the bottom and left edges of the map.
gl.bottom_labels = True
gl.left_labels = True
gl.top_labels = False
gl.right_labels = False

# Fixed grid-line positions (degrees longitude / latitude).
gl.xlocator = mticker.FixedLocator([-105.24, -105.20, -105.16])
gl.ylocator = mticker.FixedLocator([39.94, 39.96, 39.98])

# Longitude labels are rotated and left-aligned; both axes use size 14.
gl.xlabel_style = dict(size=14, rotation=315, ha='left')
gl.ylabel_style = dict(size=14)

# Tighten the layout, then set the right edge of the subplot area to 0.9
# of the figure width.
plt.tight_layout()
plt.subplots_adjust(right=0.9)

# Colorbar with one tick per fuel group; the tick positions are the
# midpoints of the color bin edges [1,2,5,9,11].
_group_ticks = [1.5,3.5,7,10]
_group_names = ['NB','GR','SH','TL']
cbar = plt.colorbar(c1,ticks=_group_ticks)
cbar.ax.set_yticklabels(_group_names)
cbar.ax.tick_params(labelsize=18)
cbar.set_label('Fuel Category',size=24,rotation=270,labelpad=30)

# Write the figure to disk, cropped to its contents.
plt.savefig('marshall_fuels.png',dpi=250,bbox_inches='tight')
